[Reinforcement Learning] Simple Bandit Algorithm Implementation

An implementation of the Simple Bandit Algorithm pseudocode from Section 2.4 of Reinforcement Learning: An Introduction.

Reinforcement Learning
Author

hoyeon

Published

March 22, 2023

Import

import numpy as np
from scipy.stats import bernoulli, norm
import random

Implementation
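
The pseudocode keeps, for each action a, a pull count N(a) and a sample-average estimate Q(a) of its true value q*(a). Rather than storing all past rewards, the estimate is updated incrementally via the key identity from Section 2.4:

$$
Q_{n+1} = Q_n + \frac{1}{n}\left(R_n - Q_n\right)
$$

Actions are chosen ε-greedily: with probability 1 − ε we take the action with the largest Q (breaking ties at random), and with probability ε we take a uniformly random action.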

def argmax(q_func):
    # Return the action with the highest estimated value,
    # breaking ties uniformly at random
    q_max = max(q_func.values())
    max_actions = [action for action, value in q_func.items() if value == q_max]
    return random.choice(max_actions)
def bandit(q_star):
    # Draw one reward sample from this arm's normal distribution
    # (mean q_star, unit variance)
    sampled_reward = norm.rvs(loc=q_star, scale=1, size=1)
    return sampled_reward[0]
def simple_bandit(num_actions, # number of actions (arms)
                  epsilon,     # probability of not acting greedily
                  q_stars,     # mean of each action's (normal) reward distribution, one per arm
                  terminate_cond
                  ):
    if num_actions != len(q_stars):
        print("num_actions must equal the number of entries in q_stars")
        return
    Q = {}
    n = {}

    # Initialize Q(a) and N(a), for a = 1 to k
    for i in range(1,num_actions+1):
        Q[i] = 0
        n[i] = 0
    
    # probability of selecting the greedy action
    max_prob = 1 - epsilon

    # epsilon-greedy action selection
    while True:

        # 1. Act greedily or randomly?
        greedy_or_random = bernoulli.rvs(p=max_prob)
        # 2. Select action
        if greedy_or_random == 1:
            action = argmax(Q)
        else:
            # choose from a list: random.sample on dict keys is
            # deprecated since Python 3.9
            action = random.choice(list(Q.keys()))

        # sample a reward from a Gaussian with mean q_star and variance 1
        q_star = q_stars[action-1]
        reward = bandit(q_star)

        # incremental Q-update
        n[action] += 1
        Q[action] = Q[action] + (1/n[action]) * (reward - Q[action])

        # Terminate once every estimate is within terminate_cond of its true value
        # (np.abs so that overestimates are caught as well as underestimates)
        if np.max(np.abs(q_stars - np.array(list(Q.values())))) <= terminate_cond:
            break
    return n,Q
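
As a quick sanity check of the tie-breaking in argmax, here is a minimal sketch (the value table is made up for illustration): with two actions tied at the maximum, repeated calls should split roughly evenly between them.

ties = {1: 0.5, 2: 0.5, 3: 0.1}
picks = [argmax(ties) for _ in range(10_000)]
print(picks.count(1) / len(picks))  # expect a value near 0.5
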
num_actions = 10
q_stars = norm.rvs(loc=0, scale=1, size=num_actions)
num_action, Q_estimated = simple_bandit(num_actions=num_actions, epsilon=0.1, q_stars=q_stars, terminate_cond=0.01)
for i in range(1,num_actions+1):
    print("action :",i)    
    print("taken_num",num_action[i])
    print("q_star :",q_stars[i-1])
    print("Q_estimated",Q_estimated[i]) 
    print("=============================================")
action : 1
taken_num 7710
q_star : 1.0622402967553395
Q_estimated 1.0780006731009761
=============================================
action : 2
taken_num 7750
q_star : -0.04393405775378968
Q_estimated -0.0464016778867334
=============================================
action : 3
taken_num 7688
q_star : -0.2977471187595547
Q_estimated -0.28941229886239545
=============================================
action : 4
taken_num 7766
q_star : -0.6942140347822109
Q_estimated -0.6957038264269383
=============================================
action : 5
taken_num 702767
q_star : 1.114864011478195
Q_estimated 1.1163009140188438
=============================================
action : 6
taken_num 7695
q_star : -2.0791842661153535
Q_estimated -2.0891168451782134
=============================================
action : 7
taken_num 7704
q_star : -1.204190113360393
Q_estimated -1.2086679112146015
=============================================
action : 8
taken_num 7754
q_star : 0.9093721477182354
Q_estimated 0.908951189305117
=============================================
action : 9
taken_num 7658
q_star : 0.0790709385366291
Q_estimated 0.08704136534792037
=============================================
action : 10
taken_num 7736
q_star : -0.06424396205536466
Q_estimated -0.05355515687875377
=============================================
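
The counts above are what ε-greedy predicts: once action 5 (the arm with the largest q_star) becomes the greedy choice, it is selected with probability roughly 1 − ε + ε/k = 0.91, while every other arm is selected only on exploration steps, with probability ε/k = 0.01. The observed counts (about 702,767 of roughly 772,000 total pulls for action 5, and about 7,700 for each of the others) match these proportions.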